"""
Tests for BaseEmbeddingProvider.
"""

import pytest
from unittest.mock import Mock

from src.local_deep_research.embeddings.providers.base import (
    BaseEmbeddingProvider,
)


class TestBaseEmbeddingProviderMetadata:
    """Tests for BaseEmbeddingProvider class metadata."""

    def test_provider_name_default(self):
        """Default provider name is 'base'."""
        assert BaseEmbeddingProvider.provider_name == "base"

    def test_provider_key_default(self):
        """Default provider key is 'BASE'."""
        assert BaseEmbeddingProvider.provider_key == "BASE"

    def test_requires_api_key_default(self):
        """Default requires_api_key is False."""
        assert BaseEmbeddingProvider.requires_api_key is False

    def test_supports_local_default(self):
        """Default supports_local is False."""
        assert BaseEmbeddingProvider.supports_local is False

    def test_default_model_is_none(self):
        """Default model is None."""
        assert BaseEmbeddingProvider.default_model is None


class TestBaseEmbeddingProviderAbstract:
    """Tests for abstract method enforcement."""

    def test_cannot_instantiate_directly(self):
        """Cannot instantiate abstract base class."""
        with pytest.raises(TypeError):
            BaseEmbeddingProvider()

    def test_subclass_must_implement_create_embeddings(self):
        """Subclass must implement create_embeddings."""

        class IncompleteProvider(BaseEmbeddingProvider):
            @classmethod
            def is_available(cls, settings_snapshot=None):
                return True

        with pytest.raises(TypeError):
            IncompleteProvider()

    def test_subclass_must_implement_is_available(self):
        """Subclass must implement is_available."""

        class IncompleteProvider(BaseEmbeddingProvider):
            @classmethod
            def create_embeddings(
                cls, model=None, settings_snapshot=None, **kwargs
            ):
                return Mock()

        with pytest.raises(TypeError):
            IncompleteProvider()


class TestBaseEmbeddingProviderMethods:
    """Tests for non-abstract methods."""

    @pytest.fixture
    def concrete_provider(self):
        """Create a concrete provider for testing."""

        class TestProvider(BaseEmbeddingProvider):
            provider_name = "Test Provider"
            provider_key = "TEST"
            requires_api_key = True
            supports_local = False
            default_model = "test-model"

            @classmethod
            def create_embeddings(
                cls, model=None, settings_snapshot=None, **kwargs
            ):
                return Mock()

            @classmethod
            def is_available(cls, settings_snapshot=None):
                return True

        return TestProvider

    def test_get_available_models_default(self, concrete_provider):
        """Default get_available_models returns empty list."""
        result = concrete_provider.get_available_models()
        assert result == []

    def test_get_model_info_default(self, concrete_provider):
        """Default get_model_info returns None."""
        result = concrete_provider.get_model_info("any-model")
        assert result is None

    def test_validate_config_available(self, concrete_provider):
        """validate_config returns True when available."""
        is_valid, error = concrete_provider.validate_config()
        assert is_valid is True
        assert error is None

    def test_validate_config_not_available(self):
        """validate_config returns False when not available."""

        class UnavailableProvider(BaseEmbeddingProvider):
            provider_name = "Unavailable"

            @classmethod
            def create_embeddings(
                cls, model=None, settings_snapshot=None, **kwargs
            ):
                return Mock()

            @classmethod
            def is_available(cls, settings_snapshot=None):
                return False

        is_valid, error = UnavailableProvider.validate_config()
        assert is_valid is False
        assert error is not None
        assert "not available" in error

    def test_get_provider_info(self, concrete_provider):
        """get_provider_info returns expected metadata."""
        info = concrete_provider.get_provider_info()

        assert info["name"] == "Test Provider"
        assert info["key"] == "TEST"
        assert info["requires_api_key"] is True
        assert info["supports_local"] is False
        assert info["default_model"] == "test-model"


class TestBaseEmbeddingProviderSubclassing:
    """Tests for subclassing BaseEmbeddingProvider."""

    def test_subclass_can_override_attributes(self):
        """Subclass can override class attributes."""

        class CustomProvider(BaseEmbeddingProvider):
            provider_name = "Custom"
            provider_key = "CUSTOM"
            requires_api_key = True
            supports_local = True
            default_model = "custom-model"

            @classmethod
            def create_embeddings(
                cls, model=None, settings_snapshot=None, **kwargs
            ):
                return Mock()

            @classmethod
            def is_available(cls, settings_snapshot=None):
                return True

        assert CustomProvider.provider_name == "Custom"
        assert CustomProvider.provider_key == "CUSTOM"
        assert CustomProvider.requires_api_key is True
        assert CustomProvider.supports_local is True
        assert CustomProvider.default_model == "custom-model"

    def test_subclass_can_override_get_available_models(self):
        """Subclass can override get_available_models."""

        class CustomProvider(BaseEmbeddingProvider):
            @classmethod
            def create_embeddings(
                cls, model=None, settings_snapshot=None, **kwargs
            ):
                return Mock()

            @classmethod
            def is_available(cls, settings_snapshot=None):
                return True

            @classmethod
            def get_available_models(cls, settings_snapshot=None):
                return [
                    {"value": "model-1", "label": "Model 1"},
                    {"value": "model-2", "label": "Model 2"},
                ]

        models = CustomProvider.get_available_models()
        assert len(models) == 2
        assert models[0]["value"] == "model-1"

    def test_subclass_can_override_get_model_info(self):
        """Subclass can override get_model_info."""

        class CustomProvider(BaseEmbeddingProvider):
            @classmethod
            def create_embeddings(
                cls, model=None, settings_snapshot=None, **kwargs
            ):
                return Mock()

            @classmethod
            def is_available(cls, settings_snapshot=None):
                return True

            @classmethod
            def get_model_info(cls, model):
                if model == "custom-model":
                    return {
                        "dimensions": 768,
                        "description": "Custom embedding model",
                    }
                return None

        info = CustomProvider.get_model_info("custom-model")
        assert info is not None
        assert info["dimensions"] == 768

        info_none = CustomProvider.get_model_info("unknown")
        assert info_none is None
